# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import argparse

import torch
from torch.autograd import Variable

from profiler.profiling_utils import Profile, profile_print
from pyro.distributions import (
    Bernoulli,
    Beta,
    Categorical,
    Cauchy,
    Dirichlet,
    Exponential,
    Gamma,
    LogNormal,
    Normal,
    OneHotCategorical,
    Poisson,
    Uniform,
)


def T(arr):
    """Wrap a (nested) list of numbers as a double-precision ``Variable``.

    The values are first copied into a ``torch.DoubleTensor`` and that
    tensor is then passed through ``Variable``.
    """
    double_tensor = torch.DoubleTensor(arr)
    return Variable(double_tensor)


# Name of the active profiling tool and its options. Both are module-level
# state that set_tool_cfg() reassigns from the command-line arguments.
TOOL = "timeit"
TOOL_CFG = {}

# Maps a distribution name to a (distribution class, constructor kwargs) pair.
# Every parameter tensor has four entries.
DISTRIBUTIONS = {
    "Bernoulli": (Bernoulli, {"probs": T([0.3] * 4)}),
    "Beta": (
        Beta,
        {"concentration1": T([2.4] * 4), "concentration0": T([3.2] * 4)},
    ),
    "Categorical": (Categorical, {"probs": T([0.1, 0.3, 0.4, 0.2])}),
    "OneHotCategorical": (OneHotCategorical, {"probs": T([0.1, 0.3, 0.4, 0.2])}),
    "Dirichlet": (Dirichlet, {"concentration": T([2.4, 3, 6, 6])}),
    "Normal": (Normal, {"loc": T([0.5] * 4), "scale": T([1.2] * 4)}),
    "LogNormal": (LogNormal, {"loc": T([0.5] * 4), "scale": T([1.2] * 4)}),
    "Cauchy": (Cauchy, {"loc": T([0.5] * 4), "scale": T([1.2] * 4)}),
    "Exponential": (Exponential, {"rate": T([5.5, 3.2, 4.1, 5.6])}),
    "Poisson": (Poisson, {"rate": T([5.5, 3.2, 4.1, 5.6])}),
    "Gamma": (Gamma, {"concentration": T([2.4] * 4), "rate": T([3.2] * 4)}),
    "Uniform": (Uniform, {"low": T([0] * 4), "high": T([4] * 4)}),
}


def get_tool():
    """Return the current value of the module-level ``TOOL`` setting.

    This function is handed to ``Profile`` as a callable, so the value is
    looked up at call time rather than when the decorator is applied.
    """
    return TOOL


def get_tool_cfg():
    """Return the current value of the module-level ``TOOL_CFG`` setting.

    This function is handed to ``Profile`` as a callable, so the value is
    looked up at call time rather than when the decorator is applied.
    """
    return TOOL_CFG


@Profile(
    tool=get_tool,
    tool_cfg=get_tool_cfg,
    # Profile id: "sample_<distribution class name>_N=<batch size>".
    fn_id=lambda dist, batch_size, *args, **kwargs: "".join(
        ("sample_", dist.dist_class.__name__, "_N=", str(batch_size))
    ),
)
def sample(dist, batch_size):
    """Draw ``batch_size`` samples from ``dist``.

    The undecorated body returns ``dist.sample`` called with a
    one-dimensional sample shape. Callers in this file unpack the decorated
    call as a ``(result, profile)`` pair.
    """
    shape = (batch_size,)
    return dist.sample(sample_shape=shape)


@Profile(
    tool=get_tool,
    tool_cfg=get_tool_cfg,
    # Profile id: "log_prob_<distribution class name>_N=<leading batch dim>".
    fn_id=lambda dist, batch, *args, **kwargs: "".join(
        ("log_prob_", dist.dist_class.__name__, "_N=", str(batch.size()[0]))
    ),
)
def log_prob(dist, batch):
    """Evaluate ``dist.log_prob`` on ``batch``.

    Callers in this file unpack the decorated call as a
    ``(result, profile)`` pair.
    """
    scores = dist.log_prob(batch)
    return scores


def run_with_tool(tool, dists, batch_sizes):
    """Profile ``sample`` and ``log_prob`` for each distribution and print a table.

    Args:
        tool: profiling tool name. "timeit" and "cprofile" select a table
            layout; any other value leaves the layout settings as ``None``.
        dists: iterable of keys into ``DISTRIBUTIONS``.
        batch_sizes: sequence of batch sizes; each contributes one SAMPLE
            and one LOG_PROB column.
    """
    widths = fmt = layout = None
    if tool == "timeit":
        # One label column followed by two timing columns per batch size.
        n_timing_cols = 2 * len(batch_sizes)
        widths = [14] * (n_timing_cols + 1)
        fmt = [None] + ["{:.6f}"] * n_timing_cols
        layout = "column"
    elif tool == "cprofile":
        widths = [14, 80]
        layout = "row"

    with profile_print(widths, fmt, layout) as out:
        headers = ["DISTRIBUTION"]
        for n in batch_sizes:
            suffix = str(n) + ")"
            headers.append("SAMPLE (N=" + suffix)
            headers.append("LOG_PROB (N=" + suffix)
        out.header(headers)

        for name in dists:
            dist_cls, kwargs = DISTRIBUTIONS[name]
            dist = dist_cls(**kwargs)
            row = [name]
            for n in batch_sizes:
                # The drawn samples are reused as the input batch for log_prob.
                drawn, sample_prof = sample(dist, batch_size=n)
                _, logpdf_prof = log_prob(dist, drawn)
                row.append(sample_prof)
                row.append(logpdf_prof)
            out.push(row)


def set_tool_cfg(args):
    """Store the tool selection from parsed arguments in module globals.

    ``TOOL`` is set to ``args.tool``. ``TOOL_CFG`` becomes
    ``{"repeat": N}`` for the "timeit" tool, where ``N`` is ``args.repeat``
    or 5 when that is ``None``; for any other tool it is an empty dict.
    """
    global TOOL, TOOL_CFG
    TOOL = args.tool
    if args.tool != "timeit":
        TOOL_CFG = {}
        return
    TOOL_CFG = {"repeat": 5 if args.repeat is None else args.repeat}


def main():
    """Parse command-line options and profile the selected distributions.

    Options:
        --tool: profiling tool name; defaults to "timeit".
        --batch_sizes: up to 4 integer batch sizes; defaults to
            [10000, 100000].
        --dists: keys of ``DISTRIBUTIONS`` to profile; defaults to all of
            them in sorted order.
        --repeat: repetition count used with the "timeit" tool; defaults
            to 5.

    Raises:
        ValueError: if more than 4 batch sizes are specified.
    """
    parser = argparse.ArgumentParser(
        description="Profiling distributions library using various tools."
    )
    parser.add_argument(
        "--tool",
        nargs="?",
        default="timeit",
        help="Profile using tool. One of following should be specified:"
        ' ["timeit", "cprofile"]',
    )
    parser.add_argument(
        "--batch_sizes",
        nargs="*",
        type=int,
        help="Batch size of tensor - max of 4 values allowed. "
        "Default = [10000, 100000]",
    )
    # Build the list of names from DISTRIBUTIONS so the help text always
    # matches the exact (case-sensitive) keys that --dists accepts.
    supported = ", ".join('"' + name + '"' for name in sorted(DISTRIBUTIONS))
    dists_help = (
        "Run tests on distributions. One or more of following distributions "
        "are supported: [" + supported + "] "
        "Default - Run profiling on all distributions"
    )
    parser.add_argument(
        "--dists",
        nargs="*",
        type=str,
        help=dists_help,
    )
    parser.add_argument(
        "--repeat",
        nargs="?",
        default=5,
        type=int,
        help='When profiling using "timeit", the number of repetitions to '
        "use for the profiled function. default=5. The minimum value "
        "is reported.",
    )
    args = parser.parse_args()
    set_tool_cfg(args)

    batch_sizes = args.batch_sizes or [10000, 100000]
    # Exactly 4 batch sizes is allowed; only more than 4 is an error.
    if len(batch_sizes) > 4:
        raise ValueError("Max of 4 batch sizes can be specified.")

    dists = args.dists or sorted(DISTRIBUTIONS.keys())
    run_with_tool(args.tool, dists, batch_sizes)


# Run the profiler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
